/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.
 */
#ifndef LUT16_SSE4_H_
#define LUT16_SSE4_H_

#include <cstdint>
#include <arm_neon.h>
#include <functional>
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <string>

#define SCANN_API_PUBLIC __attribute__((visibility("default")))

namespace hw_alg {

// Declaration only; the definition is not part of this header.
// Takes four int16x8_t accumulator vectors plus one int16x8_t threshold vector and returns a 32-bit mask.
// NOTE(review): presumably one result bit per compared 16-bit lane (4 vectors x 8 lanes = 32 bits) --
// confirm against the definition.
SCANN_API_PUBLIC uint32_t GetComparisonMask(int16x8_t int16_accumsj0, int16x8_t int16_accumsj1,
    int16x8_t int16_accumsj2, int16x8_t int16_accumsj3, int16x8_t simd_thresholdsj);

// Declaration only: no definition of this template is visible in this header.
// The parameter list is identical to that of Sse4LUT16BottomLoopExtractNew.
// NOTE(review): the function is declared inline + always_inline, so a definition has to be visible in
// every translation unit that calls it -- confirm where that definition lives.
template <size_t kNumQueries>
SCANN_API_PUBLIC inline __attribute__((always_inline)) void Sse4LUT16BottomLoopExtract(uint16x8x4_t *acc,
    const uint8_t *data_start, std::array<const uint8_t *, kNumQueries> lookup_starts, size_t num_blocks);

/*
 * Sums 16-entry table lookups for kNumQueries queries over a stream of packed 4-bit codes (AArch64 NEON).
 *
 * num_blocks / 2 steps are executed (an odd trailing block is ignored). Each step reads 32 bytes from
 * data_start and 32 table bytes from every lookup_starts[q]. The low and the high nibble of each code byte
 * are looked up separately. Bit 4 is forced on in the index of every odd byte lane (little-endian lane
 * order), so even byte lanes read table bytes 0..15 and odd byte lanes read table bytes 16..31.
 * vpadalq_u8 then adds each even/odd pair of looked-up bytes into one 16-bit lane.
 *
 * Results overwrite (they are not added to) acc[q] for q in [0, kNumQueries):
 *   acc[q].val[0]  low nibbles of code bytes 0..15 of every step
 *   acc[q].val[1]  high nibbles of code bytes 0..15 of every step
 *   acc[q].val[2]  low nibbles of code bytes 16..31 of every step
 *   acc[q].val[3]  high nibbles of code bytes 16..31 of every step
 * When num_blocks < 2 no memory is read and every acc[q] is set to zero.
 *
 * The 16-bit sums do not saturate; a single step can add up to 2 * 255 to a lane.
 *
 * The lookup is done with vqtbl2q_u8 instead of hand-written "tbl" inline asm: the two-register TBL form
 * needs both table halves in consecutive vector registers, which independent "w" asm operands cannot
 * guarantee.
 */
template <size_t kNumQueries>
SCANN_API_PUBLIC inline __attribute__((always_inline)) void Sse4LUT16BottomLoopExtractNew(uint16x8x4_t *acc,
    const uint8_t *data_start, std::array<const uint8_t *, kNumQueries> lookup_starts, size_t num_blocks)
{
    constexpr size_t k_unroll_by = 32;      // code bytes (and table bytes per query) consumed per step
    constexpr size_t kLineWidth = 16;       // one 128-bit register holds 128 / 8 = 16 bytes
    constexpr size_t kPrefetchAhead = 768;  // prefetch distance in bytes: 32 * 24 = 768
    constexpr int kNibbleBits = 4;          // shift counts of vsriq_n_u8 / vsliq_n_u8 must be constants

    uint16x8_t accLocal1[kNumQueries];
    uint16x8_t accLocal2[kNumQueries];
    uint16x8_t accLocal3[kNumQueries];
    uint16x8_t accLocal4[kNumQueries];
    for (size_t i = 0; i < kNumQueries; ++i) {
        accLocal1[i] = vdupq_n_u16(0);
        accLocal2[i] = vdupq_n_u16(0);
        accLocal3[i] = vdupq_n_u16(0);
        accLocal4[i] = vdupq_n_u16(0);
    }

    // Per-lane selectors for the second table half: odd byte lanes carry 0x10 (mask16) or 0x01 (maskone,
    // moved to bit 4 by the shift-insert), even byte lanes carry 0.
    const uint8x16_t mask16 = vreinterpretq_u8_u16(vdupq_n_u16(0x1000));
    const uint8x16_t maskone = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));

    size_t steps_left = num_blocks / 2; // 2 blocks per step
    if (steps_left != 0) {
        // The codes of the next step are loaded at the end of the current one, so nothing is ever read
        // beyond the last step.
        uint8x16_t codes = vld1q_u8(data_start);
        uint8x16_t codes_2 = vld1q_u8(data_start + kLineWidth);
        data_start += k_unroll_by;
        for (;;) {
            __builtin_prefetch(data_start + kPrefetchAhead, 0, 0);

            // High nibble -> bits 0..3, table-half selector kept in the upper bits.
            const uint8x16_t idx_hi = vsriq_n_u8(mask16, codes, kNibbleBits);
            const uint8x16_t idx_hi_2 = vsriq_n_u8(mask16, codes_2, kNibbleBits);
            // Low nibble kept, table-half selector inserted from bit 4 upwards.
            const uint8x16_t idx_lo = vsliq_n_u8(codes, maskone, kNibbleBits);
            const uint8x16_t idx_lo_2 = vsliq_n_u8(codes_2, maskone, kNibbleBits);

            for (size_t j = 0; j < kNumQueries; ++j) {
                const uint8x16x2_t dictCombine = vld1q_u8_x2(lookup_starts[j]);
                lookup_starts[j] += k_unroll_by;

                accLocal1[j] = vpadalq_u8(accLocal1[j], vqtbl2q_u8(dictCombine, idx_lo));
                accLocal2[j] = vpadalq_u8(accLocal2[j], vqtbl2q_u8(dictCombine, idx_hi));
                accLocal3[j] = vpadalq_u8(accLocal3[j], vqtbl2q_u8(dictCombine, idx_lo_2));
                accLocal4[j] = vpadalq_u8(accLocal4[j], vqtbl2q_u8(dictCombine, idx_hi_2));
            }

            if (--steps_left == 0) {
                break;
            }
            codes = vld1q_u8(data_start);
            codes_2 = vld1q_u8(data_start + kLineWidth);
            data_start += k_unroll_by;
        }
    }

    for (size_t i = 0; i < kNumQueries; ++i) {
        acc[i].val[0] = accLocal1[i];
        acc[i].val[1] = accLocal2[i];
        acc[i].val[2] = accLocal3[i];
        acc[i].val[3] = accLocal4[i];
    }
}

/*
 * Single-query specialization; the accumulators are named locals instead of arrays.
 *
 * num_blocks / 2 steps are executed (an odd trailing block is ignored). Each step reads 32 bytes from
 * data_start and 32 table bytes from lookup_starts[0]. The low and the high nibble of each code byte are
 * looked up separately. Bit 4 is forced on in the index of every odd byte lane (little-endian lane order),
 * so even byte lanes read table bytes 0..15 and odd byte lanes read table bytes 16..31. vpadalq_u8 then
 * adds each even/odd pair of looked-up bytes into one 16-bit lane.
 *
 * Results overwrite acc[0]:
 *   val[0] / val[1]  low / high nibbles of code bytes 0..15 of every step
 *   val[2] / val[3]  low / high nibbles of code bytes 16..31 of every step
 * When num_blocks < 2 no memory is read and acc[0] is set to zero. The 16-bit sums do not saturate.
 *
 * The lookup is done with vqtbl2q_u8 instead of hand-written "tbl" inline asm: the two-register TBL form
 * needs both table halves in consecutive vector registers, which independent "w" asm operands cannot
 * guarantee.
 */
template <>
SCANN_API_PUBLIC inline __attribute__((always_inline)) void Sse4LUT16BottomLoopExtractNew<1>(uint16x8x4_t *acc,
    const uint8_t *data_start, std::array<const uint8_t *, 1> lookup_starts, size_t num_blocks)
{
    constexpr size_t k_unroll_by = 32;      // code bytes (and table bytes) consumed per step
    constexpr size_t kLineWidth = 16;       // one 128-bit register holds 128 / 8 = 16 bytes
    constexpr size_t kPrefetchAhead = 768;  // prefetch distance in bytes: 32 * 24 = 768
    constexpr int kNibbleBits = 4;          // shift counts of vsriq_n_u8 / vsliq_n_u8 must be constants

    uint16x8_t sum_lo = vdupq_n_u16(0);     // -> acc[0].val[0]
    uint16x8_t sum_hi = vdupq_n_u16(0);     // -> acc[0].val[1]
    uint16x8_t sum_lo_2 = vdupq_n_u16(0);   // -> acc[0].val[2]
    uint16x8_t sum_hi_2 = vdupq_n_u16(0);   // -> acc[0].val[3]

    // Per-lane selectors for the second table half: odd byte lanes carry 0x10 (mask16) or 0x01 (maskone,
    // moved to bit 4 by the shift-insert), even byte lanes carry 0.
    const uint8x16_t mask16 = vreinterpretq_u8_u16(vdupq_n_u16(0x1000));
    const uint8x16_t maskone = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));

    size_t steps_left = num_blocks / 2; // 2 blocks per step
    if (steps_left != 0) {
        // The codes of the next step are loaded at the end of the current one, so nothing is ever read
        // beyond the last step.
        uint8x16_t codes = vld1q_u8(data_start);
        uint8x16_t codes_2 = vld1q_u8(data_start + kLineWidth);
        data_start += k_unroll_by;
        for (;;) {
            __builtin_prefetch(data_start + kPrefetchAhead, 0, 0);

            // High nibble -> bits 0..3, table-half selector kept in the upper bits.
            const uint8x16_t idx_hi = vsriq_n_u8(mask16, codes, kNibbleBits);
            const uint8x16_t idx_hi_2 = vsriq_n_u8(mask16, codes_2, kNibbleBits);
            // Low nibble kept, table-half selector inserted from bit 4 upwards.
            const uint8x16_t idx_lo = vsliq_n_u8(codes, maskone, kNibbleBits);
            const uint8x16_t idx_lo_2 = vsliq_n_u8(codes_2, maskone, kNibbleBits);

            const uint8x16x2_t table = vld1q_u8_x2(lookup_starts[0]);
            lookup_starts[0] += k_unroll_by;

            sum_lo = vpadalq_u8(sum_lo, vqtbl2q_u8(table, idx_lo));
            sum_hi = vpadalq_u8(sum_hi, vqtbl2q_u8(table, idx_hi));
            sum_lo_2 = vpadalq_u8(sum_lo_2, vqtbl2q_u8(table, idx_lo_2));
            sum_hi_2 = vpadalq_u8(sum_hi_2, vqtbl2q_u8(table, idx_hi_2));

            if (--steps_left == 0) {
                break;
            }
            codes = vld1q_u8(data_start);
            codes_2 = vld1q_u8(data_start + kLineWidth);
            data_start += k_unroll_by;
        }
    }

    acc[0].val[0] = sum_lo;
    acc[0].val[1] = sum_hi;
    acc[0].val[2] = sum_lo_2;
    acc[0].val[3] = sum_hi_2;
}

/*
 * Two-query specialization; the accumulators are named locals instead of arrays.
 *
 * num_blocks / 2 steps are executed (an odd trailing block is ignored). Each step reads 32 bytes from
 * data_start and 32 table bytes from each of lookup_starts[0] and lookup_starts[1]. The low and the high
 * nibble of each code byte are looked up separately. Bit 4 is forced on in the index of every odd byte
 * lane (little-endian lane order), so even byte lanes read table bytes 0..15 and odd byte lanes read
 * table bytes 16..31. vpadalq_u8 then adds each even/odd pair of looked-up bytes into one 16-bit lane.
 *
 * Results overwrite acc[q] for q in {0, 1}:
 *   val[0] / val[1]  low / high nibbles of code bytes 0..15 of every step
 *   val[2] / val[3]  low / high nibbles of code bytes 16..31 of every step
 * When num_blocks < 2 no memory is read and both acc entries are set to zero. The 16-bit sums do not
 * saturate.
 *
 * The lookup is done with vqtbl2q_u8 instead of hand-written "tbl" inline asm: the two-register TBL form
 * needs both table halves in consecutive vector registers, which independent "w" asm operands cannot
 * guarantee.
 */
template <>
SCANN_API_PUBLIC inline __attribute__((always_inline)) void Sse4LUT16BottomLoopExtractNew<2>(uint16x8x4_t *acc,
    const uint8_t *data_start, std::array<const uint8_t *, 2> lookup_starts, size_t num_blocks)
{
    constexpr size_t k_unroll_by = 32;      // code bytes (and table bytes per query) consumed per step
    constexpr size_t kLineWidth = 16;       // one 128-bit register holds 128 / 8 = 16 bytes
    constexpr size_t kPrefetchAhead = 768;  // prefetch distance in bytes: 32 * 24 = 768
    constexpr int kNibbleBits = 4;          // shift counts of vsriq_n_u8 / vsliq_n_u8 must be constants

    // Query 0 -> acc[0].val[0..3]
    uint16x8_t q0_lo = vdupq_n_u16(0);
    uint16x8_t q0_hi = vdupq_n_u16(0);
    uint16x8_t q0_lo_2 = vdupq_n_u16(0);
    uint16x8_t q0_hi_2 = vdupq_n_u16(0);
    // Query 1 -> acc[1].val[0..3]
    uint16x8_t q1_lo = vdupq_n_u16(0);
    uint16x8_t q1_hi = vdupq_n_u16(0);
    uint16x8_t q1_lo_2 = vdupq_n_u16(0);
    uint16x8_t q1_hi_2 = vdupq_n_u16(0);

    // Per-lane selectors for the second table half: odd byte lanes carry 0x10 (mask16) or 0x01 (maskone,
    // moved to bit 4 by the shift-insert), even byte lanes carry 0.
    const uint8x16_t mask16 = vreinterpretq_u8_u16(vdupq_n_u16(0x1000));
    const uint8x16_t maskone = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));

    size_t steps_left = num_blocks / 2; // 2 blocks per step
    if (steps_left != 0) {
        // The codes of the next step are loaded at the end of the current one, so nothing is ever read
        // beyond the last step.
        uint8x16_t codes = vld1q_u8(data_start);
        uint8x16_t codes_2 = vld1q_u8(data_start + kLineWidth);
        data_start += k_unroll_by;
        for (;;) {
            __builtin_prefetch(data_start + kPrefetchAhead, 0, 0);

            // High nibble -> bits 0..3, table-half selector kept in the upper bits.
            const uint8x16_t idx_hi = vsriq_n_u8(mask16, codes, kNibbleBits);
            const uint8x16_t idx_hi_2 = vsriq_n_u8(mask16, codes_2, kNibbleBits);
            // Low nibble kept, table-half selector inserted from bit 4 upwards.
            const uint8x16_t idx_lo = vsliq_n_u8(codes, maskone, kNibbleBits);
            const uint8x16_t idx_lo_2 = vsliq_n_u8(codes_2, maskone, kNibbleBits);

            const uint8x16x2_t table0 = vld1q_u8_x2(lookup_starts[0]);
            const uint8x16x2_t table1 = vld1q_u8_x2(lookup_starts[1]);
            lookup_starts[0] += k_unroll_by;
            lookup_starts[1] += k_unroll_by;

            q0_lo = vpadalq_u8(q0_lo, vqtbl2q_u8(table0, idx_lo));
            q0_hi = vpadalq_u8(q0_hi, vqtbl2q_u8(table0, idx_hi));
            q0_lo_2 = vpadalq_u8(q0_lo_2, vqtbl2q_u8(table0, idx_lo_2));
            q0_hi_2 = vpadalq_u8(q0_hi_2, vqtbl2q_u8(table0, idx_hi_2));

            q1_lo = vpadalq_u8(q1_lo, vqtbl2q_u8(table1, idx_lo));
            q1_hi = vpadalq_u8(q1_hi, vqtbl2q_u8(table1, idx_hi));
            q1_lo_2 = vpadalq_u8(q1_lo_2, vqtbl2q_u8(table1, idx_lo_2));
            q1_hi_2 = vpadalq_u8(q1_hi_2, vqtbl2q_u8(table1, idx_hi_2));

            if (--steps_left == 0) {
                break;
            }
            codes = vld1q_u8(data_start);
            codes_2 = vld1q_u8(data_start + kLineWidth);
            data_start += k_unroll_by;
        }
    }

    acc[0].val[0] = q0_lo;
    acc[0].val[1] = q0_hi;
    acc[0].val[2] = q0_lo_2;
    acc[0].val[3] = q0_hi_2;
    acc[1].val[0] = q1_lo;
    acc[1].val[1] = q1_hi;
    acc[1].val[2] = q1_lo_2;
    acc[1].val[3] = q1_hi_2;
}

/*
 * Three-query specialization; the accumulators are named locals instead of arrays.
 *
 * num_blocks / 2 steps are executed (an odd trailing block is ignored). Each step reads 32 bytes from
 * data_start and 32 table bytes from each of lookup_starts[0..2]. The low and the high nibble of each
 * code byte are looked up separately. Bit 4 is forced on in the index of every odd byte lane
 * (little-endian lane order), so even byte lanes read table bytes 0..15 and odd byte lanes read table
 * bytes 16..31. vpadalq_u8 then adds each even/odd pair of looked-up bytes into one 16-bit lane.
 *
 * Results overwrite acc[q] for q in {0, 1, 2}:
 *   val[0] / val[1]  low / high nibbles of code bytes 0..15 of every step
 *   val[2] / val[3]  low / high nibbles of code bytes 16..31 of every step
 * When num_blocks < 2 no memory is read and all three acc entries are set to zero. The 16-bit sums do not
 * saturate.
 *
 * The lookup is done with vqtbl2q_u8 instead of hand-written "tbl" inline asm: the two-register TBL form
 * needs both table halves in consecutive vector registers, which independent "w" asm operands cannot
 * guarantee.
 */
template <>
SCANN_API_PUBLIC inline __attribute__((always_inline)) void Sse4LUT16BottomLoopExtractNew<3>(uint16x8x4_t *acc,
    const uint8_t *data_start, std::array<const uint8_t *, 3> lookup_starts, size_t num_blocks)
{
    constexpr size_t k_unroll_by = 32;      // code bytes (and table bytes per query) consumed per step
    constexpr size_t kLineWidth = 16;       // one 128-bit register holds 128 / 8 = 16 bytes
    constexpr size_t kPrefetchAhead = 768;  // prefetch distance in bytes: 32 * 24 = 768
    constexpr int kNibbleBits = 4;          // shift counts of vsriq_n_u8 / vsliq_n_u8 must be constants

    // Query 0 -> acc[0].val[0..3]
    uint16x8_t q0_lo = vdupq_n_u16(0);
    uint16x8_t q0_hi = vdupq_n_u16(0);
    uint16x8_t q0_lo_2 = vdupq_n_u16(0);
    uint16x8_t q0_hi_2 = vdupq_n_u16(0);
    // Query 1 -> acc[1].val[0..3]
    uint16x8_t q1_lo = vdupq_n_u16(0);
    uint16x8_t q1_hi = vdupq_n_u16(0);
    uint16x8_t q1_lo_2 = vdupq_n_u16(0);
    uint16x8_t q1_hi_2 = vdupq_n_u16(0);
    // Query 2 -> acc[2].val[0..3]
    uint16x8_t q2_lo = vdupq_n_u16(0);
    uint16x8_t q2_hi = vdupq_n_u16(0);
    uint16x8_t q2_lo_2 = vdupq_n_u16(0);
    uint16x8_t q2_hi_2 = vdupq_n_u16(0);

    // Per-lane selectors for the second table half: odd byte lanes carry 0x10 (mask16) or 0x01 (maskone,
    // moved to bit 4 by the shift-insert), even byte lanes carry 0.
    const uint8x16_t mask16 = vreinterpretq_u8_u16(vdupq_n_u16(0x1000));
    const uint8x16_t maskone = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));

    size_t steps_left = num_blocks / 2; // 2 blocks per step
    if (steps_left != 0) {
        // The codes of the next step are loaded at the end of the current one, so nothing is ever read
        // beyond the last step.
        uint8x16_t codes = vld1q_u8(data_start);
        uint8x16_t codes_2 = vld1q_u8(data_start + kLineWidth);
        data_start += k_unroll_by;
        for (;;) {
            __builtin_prefetch(data_start + kPrefetchAhead, 0, 0);

            // High nibble -> bits 0..3, table-half selector kept in the upper bits.
            const uint8x16_t idx_hi = vsriq_n_u8(mask16, codes, kNibbleBits);
            const uint8x16_t idx_hi_2 = vsriq_n_u8(mask16, codes_2, kNibbleBits);
            // Low nibble kept, table-half selector inserted from bit 4 upwards.
            const uint8x16_t idx_lo = vsliq_n_u8(codes, maskone, kNibbleBits);
            const uint8x16_t idx_lo_2 = vsliq_n_u8(codes_2, maskone, kNibbleBits);

            const uint8x16x2_t table0 = vld1q_u8_x2(lookup_starts[0]);
            const uint8x16x2_t table1 = vld1q_u8_x2(lookup_starts[1]);
            const uint8x16x2_t table2 = vld1q_u8_x2(lookup_starts[2]);
            lookup_starts[0] += k_unroll_by;
            lookup_starts[1] += k_unroll_by;
            lookup_starts[2] += k_unroll_by;

            q0_lo = vpadalq_u8(q0_lo, vqtbl2q_u8(table0, idx_lo));
            q0_hi = vpadalq_u8(q0_hi, vqtbl2q_u8(table0, idx_hi));
            q0_lo_2 = vpadalq_u8(q0_lo_2, vqtbl2q_u8(table0, idx_lo_2));
            q0_hi_2 = vpadalq_u8(q0_hi_2, vqtbl2q_u8(table0, idx_hi_2));

            q1_lo = vpadalq_u8(q1_lo, vqtbl2q_u8(table1, idx_lo));
            q1_hi = vpadalq_u8(q1_hi, vqtbl2q_u8(table1, idx_hi));
            q1_lo_2 = vpadalq_u8(q1_lo_2, vqtbl2q_u8(table1, idx_lo_2));
            q1_hi_2 = vpadalq_u8(q1_hi_2, vqtbl2q_u8(table1, idx_hi_2));

            q2_lo = vpadalq_u8(q2_lo, vqtbl2q_u8(table2, idx_lo));
            q2_hi = vpadalq_u8(q2_hi, vqtbl2q_u8(table2, idx_hi));
            q2_lo_2 = vpadalq_u8(q2_lo_2, vqtbl2q_u8(table2, idx_lo_2));
            q2_hi_2 = vpadalq_u8(q2_hi_2, vqtbl2q_u8(table2, idx_hi_2));

            if (--steps_left == 0) {
                break;
            }
            codes = vld1q_u8(data_start);
            codes_2 = vld1q_u8(data_start + kLineWidth);
            data_start += k_unroll_by;
        }
    }

    acc[0].val[0] = q0_lo;
    acc[0].val[1] = q0_hi;
    acc[0].val[2] = q0_lo_2;
    acc[0].val[3] = q0_hi_2;
    acc[1].val[0] = q1_lo;
    acc[1].val[1] = q1_hi;
    acc[1].val[2] = q1_lo_2;
    acc[1].val[3] = q1_hi_2;
    acc[2].val[0] = q2_lo;
    acc[2].val[1] = q2_hi;
    acc[2].val[2] = q2_lo_2;
    acc[2].val[3] = q2_hi_2;
}

// Declaration only: no definition of this template is visible in this header.
// Takes two pointers to const T plus an element count and returns a double.
// NOTE(review): judging by the name this is the squared L2 distance between the two arrays over
// num_nonzero elements -- confirm against the definition.
// NOTE(review): the function is declared inline + always_inline, so a definition has to be visible in
// every translation unit that calls it -- confirm where that definition lives.
template <typename T>
SCANN_API_PUBLIC inline __attribute__((always_inline)) double DenseSquaredL2DistanceBatched(const T* aptr,
    const T* bptr, const size_t num_nonzero);

}  // namespace hw_alg

#endif
